# -*- coding: utf-8 -*-
"""
@author: jianping zhang

@file: bert_text_classification.py

@date: 2021/6/24

@desc: Bert文本分类模型
"""

import torch
import torch.nn as nn
from transformers import BertModel

class Bert_Text_Classfication(nn.Module):
    """Text classifier: a pretrained BERT encoder followed by a linear head.

    The vector at sequence position 0 of BERT's first output (the [CLS]
    token position) is projected to ``n_classes`` unnormalized scores.
    """

    def __init__(self, bert_config, hidden_dim, n_classes):
        """Load the BERT encoder and build the classification head.

        Args:
            bert_config: Model name or path, passed unchanged to
                ``BertModel.from_pretrained``.
            hidden_dim: Input width of the linear head. It must equal the
                loaded BERT's ``config.hidden_size``. Pass ``None`` to take
                that value automatically.
            n_classes: Number of output classes (width of the linear head).

        Raises:
            ValueError: If ``hidden_dim`` is given and differs from the BERT
                hidden size. Without this check the mismatch would only
                surface later as a shape error inside ``forward``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_config)

        bert_hidden = self.bert.config.hidden_size
        if hidden_dim is None:
            hidden_dim = bert_hidden
        elif hidden_dim != bert_hidden:
            raise ValueError(
                f"hidden_dim={hidden_dim} does not match the BERT hidden "
                f"size {bert_hidden} of '{bert_config}'"
            )

        self.hidden_dim = hidden_dim
        self.linear = nn.Linear(hidden_dim, n_classes)

    def forward(self, sentence, attention_mask=None, token_type_ids=None):
        """Classify a batch of token-id sequences.

        Args:
            sentence: Token-id tensor fed to BERT as ``input_ids``
                (assumed shape (batch_size, seq_len); TODO confirm with callers).
            attention_mask: Optional mask forwarded to BERT unchanged.
            token_type_ids: Optional segment ids forwarded to BERT, e.g. for
                sentence-pair inputs. ``None`` leaves BERT's default behavior.

        Returns:
            Unnormalized class scores of shape [batch_size, n_classes]
            (no softmax is applied).
        """
        outputs = self.bert(
            sentence,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 0 is the per-token hidden states; take position 0 ([CLS]).
        cls_vector = outputs[0][:, 0, :]
        return self.linear(cls_vector)  # [batch_size, n_classes]

